#!/usr/bin/env python3
"""
Phase 5: Endogenous Regime Identification
Tables 14-16: Smooth Transition Regression, k-means clustering, Hansen threshold test.
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path
from scipy import optimize, stats as sp_stats

PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
sys.path.insert(0, str(PROJECT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Phase-5 input/output locations under the CCA project root. Both output
# directories are created up front so later CSV/markdown/figure writes succeed.
CCA_DIR = PROJECT_DIR / "cca_tipping"
PROCESSED_DIR = CCA_DIR.joinpath("data", "processed")
TABLE_DIR = CCA_DIR.joinpath("output", "tables")
FIG_DIR = CCA_DIR.joinpath("output", "figures")
for _out_dir in (TABLE_DIR, FIG_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)

# ── Helpers ────────────────────────────────────────────────────────────────
def star(p):
    """Return the conventional significance marker for a p-value.

    ``"***"`` when p < 0.01, ``"**"`` when p < 0.05, ``"*"`` when p < 0.10,
    and ``""`` otherwise. A NaN p-value fails every comparison and so
    yields ``""``.
    """
    for cutoff, marker in ((0.01, "***"), (0.05, "**"), (0.10, "*")):
        if p < cutoff:
            return marker
    return ""

def run_gls(df, dep_var, indep_vars):
    """Fit ``PanelGLS`` of ``dep_var`` on ``indep_vars`` and summarise it.

    Rows missing the dependent variable or any regressor are dropped first.
    Returns ``None`` when fewer than 50 complete rows remain. Otherwise
    returns a dict with the model's R², observation and country counts,
    per-regressor coefficients / standard errors / p-values (keyed by the
    names in ``indep_vars``), and ``ssr``, the sum of squared residuals.

    ``df`` must carry ``iso3`` and ``year`` columns; they are passed to
    ``PanelGLS.fit`` as the panel identifiers.
    """
    sample = df.dropna(subset=[dep_var] + indep_vars).copy()
    if len(sample) < 50:
        return None

    y_vec = sample[dep_var].values
    x_mat = sample[indep_vars].values
    model = PanelGLS()
    model.fit(y_vec, x_mat, sample['iso3'].values, sample['year'].values)

    # NOTE(review): residuals use the untransformed y and X with the fitted
    # beta; if PanelGLS transforms the data internally, this SSR is not on the
    # transformed scale — confirm against model.PanelGLS.
    fitted = x_mat @ model.beta
    ssr = float(np.sum((y_vec - fitted) ** 2))

    def by_name(values):
        # Pair each estimate with its regressor name, in indep_vars order.
        return dict(zip(indep_vars, values))

    return {
        'r_squared': model.r_squared,
        'n_obs': model.n_obs,
        'n_countries': model.n_countries,
        'coefficients': by_name(model.beta),
        'std_errors': by_name(model.se),
        'p_values': by_name(model.pvalues),
        'ssr': ssr,
    }

# ── Load Data ──────────────────────────────────────────────────────────────
print("=" * 70)
print("PHASE 5: ENDOGENOUS REGIME IDENTIFICATION")
print("=" * 70)

# Estimation sample: rows with an observed ca_gdp and year in [1986, 2024].
panel = pd.read_csv(PROCESSED_DIR / "cca_panel.csv")
est = panel.loc[panel['ca_gdp'].notna() & panel['year'].between(1986, 2024)].copy()

# Demographic regressors plus the baseline controls shared by every table.
demo_vars = ['Z_1', 'Z_2', 'Z_3']
baseline_controls = [
    'fiscal_bal_gdp', 'kaopen', 'expected_growth',
    'nfa_gdp_lag', 'log_rel_opw', 'health_exp_gdp',
]
base_vars = [*demo_vars, *baseline_controls]

# ═══════════════════════════════════════════════════════════════════════════
# TABLE 14: SMOOTH TRANSITION REGRESSION (STR)
# ═══════════════════════════════════════════════════════════════════════════
from scipy.special import expit  # numerically stable logistic 1 / (1 + exp(-x))

print("\n" + "=" * 70)
print("TABLE 14: SMOOTH TRANSITION REGRESSION")
print("=" * 70)

gov_sample = est.dropna(subset=base_vars + ['ca_gdp', 'governance_composite']).copy()
print(f"  Governance sample: {gov_sample['iso3'].nunique()} countries, {len(gov_sample):,} obs")

# STR: y = X*β₁ + G(gov; γ, c) * X*β₂ + ε,  G(gov; γ, c) = 1 / (1 + exp(-γ*(gov - c)))
# Only the Z₁ loading switches here: Z₁*(1-G) enters as regime 1 (low
# governance) and Z₁*G as regime 2 (high governance); every other regressor
# has a single coefficient. (c, γ) are picked by grid search on R².
# NOTE(review): γ multiplies governance in raw units; STR applications usually
# scale γ by the transition variable's std. dev. — confirm the grid suits it.

gov_vals = gov_sample['governance_composite'].dropna()
c_grid = np.percentile(gov_vals, np.arange(15, 86, 5))  # 15th..85th pct, step 5
gamma_grid = [1, 2, 5, 10, 20, 50]
str_vars = ['Z_1_regime1', 'Z_1_regime2', 'Z_2', 'Z_3'] + baseline_controls

best_r2 = -np.inf
best_params = None
str_results = []

for c in c_grid:
    for gamma in gamma_grid:
        # Transition weight in [0, 1]; expit avoids exp() overflow for large γ.
        gov_sample['G'] = expit(gamma * (gov_sample['governance_composite'] - c))
        gov_sample['Z_1_regime1'] = gov_sample['Z_1'] * (1 - gov_sample['G'])
        gov_sample['Z_1_regime2'] = gov_sample['Z_1'] * gov_sample['G']

        r = run_gls(gov_sample, 'ca_gdp', str_vars)
        if r is None:
            continue

        str_results.append({
            'c': c,
            'gamma': gamma,
            'Z1_regime1': r['coefficients']['Z_1_regime1'],
            'Z1_regime1_pval': r['p_values']['Z_1_regime1'],
            'Z1_regime2': r['coefficients']['Z_1_regime2'],
            'Z1_regime2_pval': r['p_values']['Z_1_regime2'],
            'r_squared': r['r_squared'],
            'n_obs': r['n_obs'],
        })

        if r['r_squared'] > best_r2:
            best_r2 = r['r_squared']
            best_params = (c, gamma)

str_df = pd.DataFrame(str_results)
str_df.to_csv(TABLE_DIR / "table14_str_grid.csv", index=False)

if best_params is not None:
    best_c, best_gamma = best_params
    best_row = str_df[(str_df['c'] == best_c) & (str_df['gamma'] == best_gamma)].iloc[0]
    print(f"\n  Best STR: c={best_c:.2f}, γ={best_gamma}")
    print(f"  Regime 1 (low gov): Z₁={best_row['Z1_regime1']:.2f}{star(best_row['Z1_regime1_pval'])}")
    print(f"  Regime 2 (high gov): Z₁={best_row['Z1_regime2']:.2f}{star(best_row['Z1_regime2_pval'])}")
    print(f"  R²={best_r2:.4f}")

    # Report top 5 by R²
    top5 = str_df.nlargest(5, 'r_squared')
    print("\n  Top 5 STR specifications:")
    for _, row in top5.iterrows():
        print(f"    c={row['c']:.2f}, γ={row['gamma']:.0f}: "
              f"R1={row['Z1_regime1']:.2f}{star(row['Z1_regime1_pval'])}, "
              f"R2={row['Z1_regime2']:.2f}{star(row['Z1_regime2_pval'])}, "
              f"R²={row['r_squared']:.4f}")

# Markdown: the 10 best grid points by R²
md = ["# Table 14: Smooth Transition Regression (Best 10 by R²)\n"]
if str_df.empty:
    # No grid point had >= 50 complete obs. str_df then has no columns, and
    # nlargest(..., 'r_squared') would raise KeyError.
    print("\n  No STR specification could be estimated (insufficient obs).")
    md.append("No STR specification could be estimated (insufficient observations).")
else:
    md.append("| c (threshold) | γ (speed) | Z₁ Regime 1 | p-val | Z₁ Regime 2 | p-val | R² |")
    md.append("|---|---|---|---|---|---|---|")
    for _, row in str_df.nlargest(10, 'r_squared').iterrows():
        md.append(f"| {row['c']:.2f} | {row['gamma']:.0f} | "
                  f"{row['Z1_regime1']:.2f}{star(row['Z1_regime1_pval'])} | {row['Z1_regime1_pval']:.4f} | "
                  f"{row['Z1_regime2']:.2f}{star(row['Z1_regime2_pval'])} | {row['Z1_regime2_pval']:.4f} | "
                  f"{row['r_squared']:.4f} |")
# Explicit UTF-8: the table contains non-ASCII symbols (γ, ₁, ²).
(TABLE_DIR / "table14_str.md").write_text("\n".join(md), encoding="utf-8")

# ═══════════════════════════════════════════════════════════════════════════
# TABLE 15: K-MEANS CLUSTERING
# ═══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TABLE 15: K-MEANS CLUSTERING")
print("=" * 70)

from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler

# ISO3 codes tallied in each cluster's N_CCA column (built once, not per cluster).
CCA_ISO3 = frozenset({'ARM', 'AZE', 'BLR', 'GEO', 'KAZ', 'KGZ',
                      'MDA', 'MNG', 'RUS', 'TJK', 'TKM', 'UKR', 'UZB'})

cluster_vars = ['governance_composite', 'kaopen']
cluster_sample = est.dropna(subset=base_vars + ['ca_gdp'] + cluster_vars).copy()

# Cluster countries (not country-years) on their sample-mean characteristics.
country_means = cluster_sample.groupby('iso3')[cluster_vars].mean().dropna()
print(f"  Countries with clustering data: {len(country_means)}")

# Standardize so both clustering variables carry equal weight in the distance.
# Computed once, before the per-k label columns are appended to country_means.
scaler = StandardScaler()
if len(country_means) > 0:
    X_cluster = scaler.fit_transform(country_means[cluster_vars].values)
else:
    # StandardScaler rejects an empty matrix; every k below is then skipped.
    X_cluster = np.empty((0, len(cluster_vars)))

# Try k=2, 3, 4
cluster_rows = []
for k in [2, 3, 4]:
    if k > len(country_means):
        # KMeans raises ValueError when there are fewer samples than clusters.
        print(f"\n  k={k} clusters: skipped (only {len(country_means)} countries)")
        continue

    km = KMeans(n_clusters=k, random_state=42, n_init=10)
    labels = km.fit_predict(X_cluster)
    country_means[f'cluster_k{k}'] = labels

    # Map labels back to the country-year panel ('cluster' is overwritten per k).
    cluster_map = dict(zip(country_means.index, labels))
    cluster_sample['cluster'] = cluster_sample['iso3'].map(cluster_map)

    print(f"\n  k={k} clusters:")
    for cl in range(k):
        cl_countries = country_means.index[labels == cl].tolist()
        sub = cluster_sample[cluster_sample['cluster'] == cl]
        n_c = sub['iso3'].nunique()
        n_cca = len(CCA_ISO3.intersection(cl_countries))
        mean_gov = country_means.loc[cl_countries, 'governance_composite'].mean()
        mean_kaopen = country_means.loc[cl_countries, 'kaopen'].mean()

        r = run_gls(sub, 'ca_gdp', base_vars)
        if r is None:
            print(f"    Cluster {cl}: {n_c} countries (insufficient obs)")
            continue

        row = {
            'k': k,
            'cluster': cl,
            'n_countries': n_c,
            'n_cca': n_cca,
            'mean_governance': mean_gov,
            'mean_kaopen': mean_kaopen,
            'n_obs': r['n_obs'],
            'r_squared': r['r_squared'],
        }
        for v in demo_vars:
            row[f'{v}_coef'] = r['coefficients'][v]
            row[f'{v}_pval'] = r['p_values'][v]
        cluster_rows.append(row)
        print(f"    Cluster {cl}: {n_c} countries ({n_cca} CCA), "
              f"gov={mean_gov:.2f}, kaopen={mean_kaopen:.2f}, "
              f"Z₁={row['Z_1_coef']:.2f}{star(row['Z_1_pval'])}")

cluster_df = pd.DataFrame(cluster_rows)
cluster_df.to_csv(TABLE_DIR / "table15_kmeans.csv", index=False)

md = ["# Table 15: K-Means Clustering Results\n"]
md.append("| k | Cluster | N_c | N_CCA | Mean gov | Mean KAOPEN | Z₁ | p-val | R² |")
md.append("|---|---|---|---|---|---|---|---|---|")
for _, row in cluster_df.iterrows():
    # iterrows upcasts the mixed int/float row to float, so integer columns
    # need :.0f to avoid rendering as e.g. "12.0".
    md.append(f"| {row['k']:.0f} | {row['cluster']:.0f} | {row['n_countries']:.0f} | {row['n_cca']:.0f} | "
              f"{row['mean_governance']:.2f} | {row['mean_kaopen']:.2f} | "
              f"{row['Z_1_coef']:.2f}{star(row['Z_1_pval'])} | {row['Z_1_pval']:.4f} | "
              f"{row['r_squared']:.4f} |")
(TABLE_DIR / "table15_kmeans.md").write_text("\n".join(md), encoding="utf-8")

# ═══════════════════════════════════════════════════════════════════════════
# TABLE 16: HANSEN THRESHOLD TEST
# ═══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TABLE 16: HANSEN THRESHOLD TEST (governance)")
print("=" * 70)

# Hansen (2000) threshold search: for each candidate governance threshold τ,
# split the sample at τ, fit each side separately and record SSR₁ + SSR₂.
# The optimal τ* minimizes total SSR; a Chow-type F compares it to the pooled fit.
# NOTE(review): the F p-value treats τ as known. Because τ* is chosen by search,
# the statistic's null distribution is non-standard (Hansen 1996); a bootstrap
# sup-F p-value would be needed for formal inference.

gov_sample_h = est.dropna(subset=base_vars + ['ca_gdp', 'governance_composite']).copy()
r_pooled = run_gls(gov_sample_h, 'ca_gdp', base_vars)
ssr_pooled = r_pooled['ssr'] if r_pooled else np.inf
n_pooled = r_pooled['n_obs'] if r_pooled else 0
n_params = len(base_vars)  # coefficients estimated in each regime

# Candidate thresholds: trim 15% from each tail of the governance
# distribution, then take the 5th..95th percentiles (step 2) of the remainder.
gov_sorted = gov_sample_h['governance_composite'].dropna().sort_values()
trim = int(0.15 * len(gov_sorted))
# Explicit upper bound: values[trim:-trim] would be empty when trim == 0.
gov_trimmed = gov_sorted.values[trim:len(gov_sorted) - trim]
# np.unique drops repeated percentiles (likely if governance takes few distinct
# values) so the same split is not estimated twice.
if len(gov_trimmed) > 0:
    threshold_grid = np.unique(np.percentile(gov_trimmed, np.arange(5, 96, 2)))
else:
    threshold_grid = np.array([])

hansen_rows = []
for tau in threshold_grid:
    below = gov_sample_h[gov_sample_h['governance_composite'] <= tau]
    above = gov_sample_h[gov_sample_h['governance_composite'] > tau]

    r_below = run_gls(below, 'ca_gdp', base_vars)
    r_above = run_gls(above, 'ca_gdp', base_vars)
    if r_below is None or r_above is None:
        continue

    ssr_1, ssr_2 = r_below['ssr'], r_above['ssr']
    n1, n2 = r_below['n_obs'], r_above['n_obs']
    ssr_split = ssr_1 + ssr_2
    df_denom = n1 + n2 - 2 * n_params

    # Chow-type F: ((SSR_pooled - SSR_split) / k) / (SSR_split / (n - 2k)).
    # Guarded so a non-positive df or zero SSR gives NaN instead of a
    # division by zero. sf() is used instead of 1 - cdf() for tail accuracy.
    # NOTE(review): GLS beta does not minimise raw SSR, so the numerator can
    # be negative (then p = 1).
    if df_denom > 0 and ssr_split > 0:
        F_stat = ((ssr_pooled - ssr_split) / n_params) / (ssr_split / df_denom)
        p_val = sp_stats.f.sf(F_stat, n_params, df_denom)
    else:
        F_stat = np.nan
        p_val = np.nan

    hansen_rows.append({
        'threshold': tau,
        'n_below': n1,
        'n_above': n2,
        'nc_below': r_below['n_countries'],
        'nc_above': r_above['n_countries'],
        'ssr_below': ssr_1,
        'ssr_above': ssr_2,
        'ssr_total': ssr_split,
        'F_stat': F_stat,
        'p_value': p_val,
        'Z1_below': r_below['coefficients']['Z_1'],
        'Z1_below_pval': r_below['p_values']['Z_1'],
        'Z1_above': r_above['coefficients']['Z_1'],
        'Z1_above_pval': r_above['p_values']['Z_1'],
        'r2_below': r_below['r_squared'],
        'r2_above': r_above['r_squared'],
    })

hansen_df = pd.DataFrame(hansen_rows)

# Find optimal threshold (minimizes total SSR)
if len(hansen_df) > 0:
    best_idx = hansen_df['ssr_total'].idxmin()
    best = hansen_df.loc[best_idx]
    print(f"\n  Optimal threshold: governance = {best['threshold']:.3f}")
    print(f"  Below ({best['nc_below']:.0f}c, {best['n_below']:.0f} obs): "
          f"Z₁={best['Z1_below']:.2f}{star(best['Z1_below_pval'])}")
    print(f"  Above ({best['nc_above']:.0f}c, {best['n_above']:.0f} obs): "
          f"Z₁={best['Z1_above']:.2f}{star(best['Z1_above_pval'])}")
    print(f"  F-stat={best['F_stat']:.2f}, p={best['p_value']:.4f}")

    # Confidence interval: all thresholds whose LR statistic is below the
    # critical value, with LR(τ) = n * (SSR(τ) - SSR(τ*)) / SSR(τ*).
    ssr_star = best['ssr_total']
    hansen_df['LR'] = n_pooled * (hansen_df['ssr_total'] - ssr_star) / ssr_star
    # 95% CI: LR <= 7.35 (Hansen 2000 asymptotic critical value, single threshold)
    ci = hansen_df[hansen_df['LR'] <= 7.35]
    if len(ci) > 0:
        print(f"  95% CI: [{ci['threshold'].min():.3f}, {ci['threshold'].max():.3f}]")

# Written after the LR column is added so the CSV includes it.
hansen_df.to_csv(TABLE_DIR / "table16_hansen_threshold.csv", index=False)

# Markdown: the 10 thresholds with the lowest total SSR
md = ["# Table 16: Hansen Threshold Test (Governance)\n"]
if r_pooled is None:
    md.append("Pooled regression not estimable (fewer than 50 complete observations).\n")
else:
    md.append(f"Pooled SSR: {ssr_pooled:.2f}, N={r_pooled['n_obs']}\n")
md.append("| Threshold | N below | N above | Z₁ below | p-val | Z₁ above | p-val | F | p(F) |")
md.append("|---|---|---|---|---|---|---|---|---|")
if not hansen_df.empty:
    # nsmallest on a column-less (empty) frame would raise KeyError.
    for _, row in hansen_df.nsmallest(10, 'ssr_total').iterrows():
        md.append(f"| {row['threshold']:.3f} | {row['n_below']:.0f} | {row['n_above']:.0f} | "
                  f"{row['Z1_below']:.2f}{star(row['Z1_below_pval'])} | {row['Z1_below_pval']:.4f} | "
                  f"{row['Z1_above']:.2f}{star(row['Z1_above_pval'])} | {row['Z1_above_pval']:.4f} | "
                  f"{row['F_stat']:.2f} | {row['p_value']:.4f} |")
(TABLE_DIR / "table16_hansen_threshold.md").write_text("\n".join(md), encoding="utf-8")

# Figure: total SSR and regime-specific Z₁ coefficients across thresholds
if len(hansen_df) > 2:
    import matplotlib
    matplotlib.use("Agg")  # non-interactive backend: render straight to file
    import matplotlib.pyplot as plt

    # The optimum marker is the SSR-minimizing threshold, read from hansen_df
    # itself so it is always drawn whenever the Hansen search produced rows.
    opt_threshold = hansen_df.loc[hansen_df['ssr_total'].idxmin(), 'threshold']

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))

    ax1.plot(hansen_df['threshold'], hansen_df['ssr_total'], 'b-')
    ax1.axvline(opt_threshold, color='r', linestyle='--', label=f"Optimal: {opt_threshold:.2f}")
    ax1.set_xlabel("Governance threshold")
    ax1.set_ylabel("Total SSR")
    ax1.set_title("Hansen Threshold Test: SSR by Governance Threshold")
    ax1.legend()

    ax2.plot(hansen_df['threshold'], hansen_df['Z1_below'], 'b-o', markersize=3, label='Z₁ (below threshold)')
    ax2.plot(hansen_df['threshold'], hansen_df['Z1_above'], 'r-s', markersize=3, label='Z₁ (above threshold)')
    ax2.axhline(0, color='k', linestyle='--', linewidth=0.5)
    ax2.set_xlabel("Governance threshold")
    ax2.set_ylabel("Z₁ coefficient")
    ax2.set_title("Z₁ Coefficient by Regime")
    ax2.legend()

    fig.tight_layout()
    fig_path = FIG_DIR / "fig2_hansen_threshold.png"
    fig.savefig(fig_path, dpi=150, bbox_inches="tight")
    plt.close(fig)  # close this specific figure, not whatever is "current"
    print(f"\n  Saved figure: {fig_path}")

print(f"\nAll Phase 5 tables saved to {TABLE_DIR}")
print("Phase 5 complete.")
